-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updates from conformer #338
base: main
Are you sure you want to change the base?
Conversation
Somewhere between pymatgen 2024.5.1 and 2024.8.9 they fixed space group 19 to be |
oh, I see, thank you for noticing this @ftherrien ! Do you think we should change |
@AlexandraVolokhova Good question. I just saw the same issue in a different PR and basically came to the same conclusion that you guys did. I think that we can change the file space_groups.yaml to fix the typo but we might want to put a comment in the corresponding test to say something like "If this test is the only one failing, you might have an older version of pymatgen in which there was a typo in the space group name". |
@carriepl thank you, that sounds like a good plan! I can make a separate PR for that and we can rebase ours once it is merged into main |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems overall good to me. I just had a comment and a question.
@@ -1449,7 +1449,6 @@ def fit_kde( | |||
bandwidth : float | |||
The bandwidth of the kernel. | |||
""" | |||
samples = torch2np(samples) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this change because Sklearns supports tensortypes when fitting the KernelDensity?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is because env.states2kde(states)
is always called before calling fit_kde
and torch2np happens there
] | ||
self._parents_policy_available = True | ||
# hacky way to check whether it is MoleculeGraph without importing MoleculeGraph here | ||
if hasattr(self.states_policy, "num_nodes"): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That part makes me kinda uncomfortable. It feels like someone could create a new policy class and trigger this condition entirely by accident and get a really really weird bug. Maybe we could find a more reliable way to express this condition.
Can you tell me a bit more about why this is required? Why was the previous code not working in this case?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I don't really like it either but I don't know how to make it better. So, the problem is that for the new conformer env, states_policy are graphs, which are instances of a MoleculeGraph object, which exists only in the private repo, not here, so I cannot import it and check whether states_policy is the instance of this class. In the same time, I need to use a different way to collate states_policy if they are graphs, it is not possible to just create an empty tensor and fill it in as it is done here for ordinary states_policy
Hi @AlexandraVolokhova, as @ftherrien suggested the issue is the version mismatch, I was also surprised with this issue recently and figured out that the issue is a mismatch of versions when tested on the CI and locally. Two ways to fix:
|
Thank you @engmubarak48 ! I actually already fixed this issue in #339 and merged it to main. Not sure if I need to merge it to this PR as well |
Thanks @AlexandraVolokhova I should have asked you. I updated my PRs to sync with the main without downgrading the versions. And now that error disappears and tests run well. |
@@ -48,7 +55,7 @@ def __init__( | |||
# Call reset() to set initial state, done, n_actions | |||
self.reset() | |||
# Device | |||
self.device = set_device(device) | |||
self.set_device(set_device(device)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this change? Is this correct? :/
@@ -757,6 +775,15 @@ def traj2readable(self, traj=None): | |||
""" | |||
return str(traj).replace("(", "[").replace(")", "]").replace(",", "") | |||
|
|||
def states2kde( | |||
self, states: Union[List, TensorType["batch", "state_dim"]] | |||
) -> Union[List, npt.NDArray, TensorType["batch", "kde_dim"]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe the return type is always npt.NDArray
isn't it?
Several updates from the conformer project:
python -m pytest ./tests/
black ./gflownet/
black ./tests/
isort --profile black ./gflownet/
isort --profile black ./tests/